import sys
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import torchvision.models as models
import torch.utils.data as Data
import random
import os
from scipy.io import loadmat
from copy import deepcopy
from model.models import *
from model.resnet import *
from model.vgg import *
from model.googlenet import *

# Compute device shared by every function in this module: the GPU when one is visible, else the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Training-time augmentation: pad by 4 pixels and take a random 32x32 crop,
# randomly flip horizontally, then convert to a tensor.
# NOTE(review): not referenced in this part of the file; presumably used elsewhere.
data_aug = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)
# Load a full dataset into memory (torchvision datasets are downloaded on first use).
def _load_svhn_split(filename):
    """Read one SVHN .mat split from ./data/ as (float images, 0-based int64 labels)."""
    mat = loadmat(os.path.join('./data/', filename))
    # Assumes the standard SVHN layout X: (H, W, C, N); transpose to (N, C, H, W).
    # Pixel values stay in 0..255 (no /255 scaling, unlike ToTensor()).
    images = torch.from_numpy(mat['X'].transpose(3, 2, 0, 1)).float()
    # SVHN labels are stored 1-based; shift to 0-based class indices.
    labels = torch.from_numpy((mat['y'].flatten() - 1).astype(np.int64))
    return images, labels


def _load_full_batch(dataset):
    """Return the whole ``dataset`` as a single shuffled (X, Y) batch."""
    loader = Data.DataLoader(dataset=dataset, batch_size=len(dataset), shuffle=True, num_workers=8)
    return next(iter(loader))


def extract_data(data_type):
    """Load a complete train/test dataset into memory.

    This is a generator that yields exactly once; callers use
    ``next(extract_data(data_type))``.

    Args:
        data_type: one of "mnist", "fmnist", "kmnist", "cifar10", "svhn".
            torchvision datasets are stored under ./data/ (train split is
            downloaded if missing); SVHN expects ./data/train_32x32.mat and
            ./data/test_32x32.mat to already exist.

    Yields:
        (train_X, train_Y, test_X, test_Y) with labels one-hot encoded by
        ``binarize_class``. For the MNIST family only ``train_X`` is
        flattened to (N, -1); ``test_X`` keeps the loader's image shape.
        NOTE(review): presumably the consuming model flattens inputs itself — confirm.

    Raises:
        ValueError: if ``data_type`` is not supported (raised on ``next()``).
    """
    if data_type not in ("mnist", "fmnist", "kmnist", "cifar10", "svhn"):
        raise ValueError("unsupported data_type: %r" % (data_type,))

    if data_type == "svhn":
        train_X, train_Y = _load_svhn_split('train_32x32.mat')
        test_X, test_Y = _load_svhn_split('test_32x32.mat')
    else:
        dataset_cls = {
            "mnist": dsets.MNIST,
            "fmnist": dsets.FashionMNIST,
            "kmnist": dsets.KMNIST,
            "cifar10": dsets.CIFAR10,
        }[data_type]
        train_dataset = dataset_cls(root='./data/', train=True, transform=transforms.ToTensor(), download=True)
        test_dataset = dataset_cls(root='./data/', train=False, transform=transforms.ToTensor())
        # Train is drawn before test, matching the original order of torch RNG use.
        train_X, train_Y = _load_full_batch(train_dataset)
        if data_type in ("mnist", "fmnist", "kmnist"):
            train_X = train_X.view((train_X.shape[0], -1)).float()
        test_X, test_Y = _load_full_batch(test_dataset)

    yield train_X, binarize_class(train_Y), test_X, binarize_class(test_Y)
    
def binarize_class(y, num_classes=10):
    """One-hot encode integer class labels.

    Args:
        y: integer labels (tensor, array or sequence), one per sample.
        num_classes: width of the one-hot encoding (default 10, as before).

    Returns:
        float32 tensor of shape (len(y), num_classes), on the same device as
        ``y`` when ``y`` is a tensor.
    """
    idx = torch.as_tensor(y).reshape(-1, 1).long()
    onehot = torch.zeros(idx.size(0), num_classes, device=idx.device)
    onehot.scatter_(dim=1, index=idx, value=1)
    return onehot

# "Min-k" setting: complementary labels drawn from the k least-likely classes under a pretrained model.
def _pretrained_checkpoint_path(data_type, pre_model):
    """Return the checkpoint path of the pretrained scorer for ``data_type``/``pre_model``."""
    if data_type in ("mnist", "fmnist", "kmnist"):
        return './check/pretrain/pre_' + data_type + '.pt'
    if data_type == "cifar10":
        if pre_model not in ('resnet18', 'resnet34', 'vgg16', 'googlenet'):
            raise ValueError("unsupported pre_model for cifar10: %r" % (pre_model,))
        # The checkpoint file names literally contain "()".
        return './check/pretrain/cifar_' + pre_model + '().pt'
    if data_type == "svhn":
        return './check/pretrain/svhn.pt'
    raise ValueError("unsupported data_type: %r" % (data_type,))


def generate_mincompl_labels(train_X, train_Y, model, rate=0.4, batch_size=256,label_num=3,data_type="mnist",pre_model="resnet18"):
    """Generate instance-dependent ("min-k") complementary labels.

    The pretrained ``model`` scores every class for each sample. Among the
    classes other than the true one, the ``label_num`` with the lowest
    softmax probability form a pool, and one pool member is drawn uniformly
    at random (with ``np.random``) as that sample's complementary label.

    Args:
        train_X: inputs, batched along dim 0.
        train_Y: one-hot labels of shape (N, num_classes).
        model: network matching the checkpoint chosen by ``data_type`` /
            ``pre_model``; its weights are overwritten by the checkpoint and
            its train/eval mode is restored before returning.
        rate: kept for backward compatibility only. It used to rescale the
            scores by a positive per-row factor, which never changes their
            ranking, so it has no effect on the result.
        batch_size: samples per forward pass.
        label_num: size of the low-probability pool (clamped to
            ``num_classes - 1``).
        data_type, pre_model: select the checkpoint under ./check/pretrain/.

    Returns:
        CPU int64 tensor of shape (N,) with one complementary label per sample.

    Raises:
        ValueError: unknown ``data_type``/``pre_model`` (previously these
            silently used an untrained model), or ``label_num < 1``.
    """
    if label_num < 1:
        raise ValueError("label_num must be >= 1, got %r" % (label_num,))
    checkpoint = _pretrained_checkpoint_path(data_type, pre_model)

    model = model.to(device)
    was_training = model.training
    chunks = []
    try:
        with torch.no_grad():
            # Map onto the module's device (was hard-coded to cuda:0, which breaks CPU-only runs).
            model.load_state_dict(torch.load(checkpoint, map_location=device))
            # Inference mode: BatchNorm uses its running statistics (and does not
            # update them), Dropout is disabled.
            model.eval()
            # Stepping by batch_size avoids the empty trailing batch the old
            # range(step + 1) produced when N was a multiple of batch_size.
            for start in range(0, train_X.size(0), batch_size):
                batch_X = train_X[start:start + batch_size].to(device)
                batch_Y = train_Y[start:start + batch_size].to(device)
                probs = F.softmax(model(batch_X), dim=1)
                n_rows, n_classes = probs.shape
                rows = torch.arange(n_rows, device=probs.device)
                # Push the true class to the end of the ascending order so it is
                # never in the pool; uses the model's full class width instead of
                # the per-batch max label, which broke when a batch lacked the top class.
                probs[rows, batch_Y.argmax(dim=1)] = float('inf')
                pool_size = min(label_num, n_classes - 1)
                pool = torch.argsort(probs, dim=1)[:, :pool_size].cpu().numpy()
                pick = np.random.randint(0, pool_size, n_rows)
                chunks.append(pool[np.arange(n_rows), pick])
    finally:
        model.train(was_training)

    if not chunks:
        return torch.empty(0, dtype=torch.int64)
    return torch.from_numpy(np.concatenate(chunks).astype(np.int64))

# Uniform setting: complementary labels drawn uniformly from all non-true classes.
def generate_compl_labels(labels):
    """Draw one uniformly random complementary label per sample.

    Args:
        labels: one-hot tensor of shape (N, K) holding the true classes.

    Returns:
        CPU int64 tensor of shape (N,); entry i is drawn uniformly (with
        ``np.random``) from the K - 1 classes other than row i's true class.
        K is the one-hot width, so classes absent from ``labels`` can still
        be drawn (previously K was the largest observed label + 1).
    """
    true = labels.argmax(dim=1).cpu().numpy()
    num_classes = labels.size(1)
    # Draw a position among the K-1 non-true classes, then skip over the true
    # class: positions >= the true label shift up by one. This equals indexing
    # the sorted list of non-true classes without building an (N, K) matrix.
    offset = np.random.randint(0, num_classes - 1, len(true))
    complementary_labels = offset + (offset >= true)
    return torch.from_numpy(complementary_labels.astype(np.int64))
def class_prior(complementary_labels, num_classes=None):
    """Return the empirical class frequencies of ``complementary_labels``.

    Args:
        complementary_labels: 1-D non-negative integer labels (sequence,
            ndarray or CPU tensor).
        num_classes: optional minimum length of the result, so classes that
            never occur (including the highest ones) get prior 0. The default
            ``None`` keeps the old length of max label + 1.

    Returns:
        float ndarray of per-class frequencies summing to 1.
    """
    labels = np.asarray(complementary_labels)
    counts = np.bincount(labels, minlength=num_classes or 0)
    return counts / len(labels)
def _build_complementary_net(data_type, pre_model):
    """Instantiate the network whose pretrained weights score candidate complementary labels."""
    if data_type in ("mnist", "fmnist", "kmnist"):
        return mlp_model(input_dim=28*28, hidden_dim=500, output_dim=10)
    if data_type in ("cifar10", "svhn"):
        builders = {
            'resnet18': resnet18,
            'resnet34': resnet34,
            'vgg16': VGGNet,
            'googlenet': googlenet,
        }
        if pre_model not in builders:
            raise ValueError("unsupported pre_model: %r" % (pre_model,))
        return builders[pre_model]().to(device)
    raise ValueError("unsupported data_type: %r" % (data_type,))


def load_instance_dependent_dataloader(batch_size,data_type,label_num,pre_model):
    """Build loaders with both min-k and uniform complementary labels.

    Args:
        batch_size: batch size of every returned loader (also used when
            scoring samples with the pretrained network).
        data_type: dataset name understood by ``extract_data``.
        label_num: pool size passed to ``generate_mincompl_labels``.
        pre_model: architecture for "cifar10"/"svhn" ('resnet18', 'resnet34',
            'vgg16', 'googlenet'); ignored for the MNIST family.

    Returns:
        (loader, train_loader, test_loader, ccp_com, ccp_mincom):
        ``loader`` yields (x, true_label, min-k label, uniform label);
        ``train_loader``/``test_loader`` yield (x, true_label);
        ``ccp_com``/``ccp_mincom`` are ``class_prior`` of the uniform /
        min-k complementary labels. All loaders shuffle.

    Raises:
        ValueError: unsupported ``data_type`` or ``pre_model``.
    """
    train_X, train_Y, test_X, test_Y = next(extract_data(data_type))
    # Built after loading to keep the original torch RNG order
    # (dataset shuffling first, then weight initialisation).
    complementary_net = _build_complementary_net(data_type, pre_model)
    # pre_model must be forwarded so the checkpoint matches the architecture
    # built above (it used to always default to the resnet18 checkpoint).
    mincom = generate_mincompl_labels(train_X=train_X, train_Y=train_Y, model=complementary_net,
                                      batch_size=batch_size, label_num=label_num,
                                      data_type=data_type, pre_model=pre_model)
    com = generate_compl_labels(labels=train_Y)
    # One-hot -> class index.
    train_Y = train_Y.argmax(dim=1)
    test_Y = test_Y.argmax(dim=1)
    ccp_com = class_prior(com)
    ccp_mincom = class_prior(mincom)

    def make_loader(*tensors):
        dataset = torch.utils.data.TensorDataset(*tensors)
        return torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

    loader = make_loader(train_X, train_Y, mincom, com)
    train_loader = make_loader(train_X, train_Y)
    test_loader = make_loader(test_X, test_Y)
    return loader, train_loader, test_loader, ccp_com, ccp_mincom